{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sampling and Splitting"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sampling\n",
    "\n",
    "Sampling for basic tabular datasets. (Not designed for time series as of now.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import polars as pl\n",
    "import polars_ds as pds\n",
    "import polars_ds.sample_and_split as sa"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><style>\n",
       ".dataframe > thead > tr,\n",
       ".dataframe > tbody > tr {\n",
       "  text-align: right;\n",
       "  white-space: pre-wrap;\n",
       "}\n",
       "</style>\n",
       "<small>shape: (5, 8)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>row_num</th><th>uniform_1</th><th>uniform_2</th><th>exp</th><th>normal</th><th>fat_normal</th><th>flags</th><th>category</th></tr><tr><td>i64</td><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>i32</td><td>str</td></tr></thead><tbody><tr><td>0</td><td>3.169978</td><td>0.183096</td><td>0.848878</td><td>-0.988939</td><td>369.76195</td><td>2</td><td>&quot;A&quot;</td></tr><tr><td>1</td><td>8.810768</td><td>0.569672</td><td>0.048483</td><td>-0.44255</td><td>258.012662</td><td>0</td><td>&quot;A&quot;</td></tr><tr><td>2</td><td>3.274063</td><td>0.632772</td><td>0.447468</td><td>0.255512</td><td>-1284.389879</td><td>1</td><td>&quot;A&quot;</td></tr><tr><td>3</td><td>10.847672</td><td>0.89006</td><td>0.772062</td><td>0.735149</td><td>-0.362983</td><td>0</td><td>&quot;A&quot;</td></tr><tr><td>4</td><td>11.66482</td><td>0.907167</td><td>1.393929</td><td>2.285448</td><td>-2031.321622</td><td>0</td><td>&quot;A&quot;</td></tr></tbody></table></div>"
      ],
      "text/plain": [
       "shape: (5, 8)\n",
       "┌─────────┬───────────┬───────────┬──────────┬───────────┬──────────────┬───────┬──────────┐\n",
       "│ row_num ┆ uniform_1 ┆ uniform_2 ┆ exp      ┆ normal    ┆ fat_normal   ┆ flags ┆ category │\n",
       "│ ---     ┆ ---       ┆ ---       ┆ ---      ┆ ---       ┆ ---          ┆ ---   ┆ ---      │\n",
       "│ i64     ┆ f64       ┆ f64       ┆ f64      ┆ f64       ┆ f64          ┆ i32   ┆ str      │\n",
       "╞═════════╪═══════════╪═══════════╪══════════╪═══════════╪══════════════╪═══════╪══════════╡\n",
       "│ 0       ┆ 3.169978  ┆ 0.183096  ┆ 0.848878 ┆ -0.988939 ┆ 369.76195    ┆ 2     ┆ A        │\n",
       "│ 1       ┆ 8.810768  ┆ 0.569672  ┆ 0.048483 ┆ -0.44255  ┆ 258.012662   ┆ 0     ┆ A        │\n",
       "│ 2       ┆ 3.274063  ┆ 0.632772  ┆ 0.447468 ┆ 0.255512  ┆ -1284.389879 ┆ 1     ┆ A        │\n",
       "│ 3       ┆ 10.847672 ┆ 0.89006   ┆ 0.772062 ┆ 0.735149  ┆ -0.362983    ┆ 0     ┆ A        │\n",
       "│ 4       ┆ 11.66482  ┆ 0.907167  ┆ 1.393929 ┆ 2.285448  ┆ -2031.321622 ┆ 0     ┆ A        │\n",
       "└─────────┴───────────┴───────────┴──────────┴───────────┴──────────────┴───────┴──────────┘"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pds.frame(size=100_000).with_columns(\n",
    "    pds.random(0.0, 12.0).alias(\"uniform_1\"),\n",
    "    pds.random(0.0, 1.0).alias(\"uniform_2\"),\n",
    "    pds.random_exp(0.5).alias(\"exp\"),\n",
    "    pds.random_normal(0.0, 1.0).alias(\"normal\"),\n",
    "    pds.random_normal(0.0, 1000.0).alias(\"fat_normal\"),\n",
    "    (pds.random_int(0, 3)).alias(\"flags\"),\n",
    "    pl.Series([\"A\"] * 30_000 + [\"B\"] * 30_000 + [\"C\"] * 40_000).alias(\"category\"),\n",
    ")\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['row_num', 'normal', 'flags']"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sa.random_cols(df, 2, keep = [\"row_num\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><style>\n",
       ".dataframe > thead > tr,\n",
       ".dataframe > tbody > tr {\n",
       "  text-align: right;\n",
       "  white-space: pre-wrap;\n",
       "}\n",
       "</style>\n",
       "<small>shape: (60_000, 8)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>row_num</th><th>uniform_1</th><th>uniform_2</th><th>exp</th><th>normal</th><th>fat_normal</th><th>flags</th><th>category</th></tr><tr><td>i64</td><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>i32</td><td>str</td></tr></thead><tbody><tr><td>1</td><td>8.810768</td><td>0.569672</td><td>0.048483</td><td>-0.44255</td><td>258.012662</td><td>0</td><td>&quot;A&quot;</td></tr><tr><td>2</td><td>3.274063</td><td>0.632772</td><td>0.447468</td><td>0.255512</td><td>-1284.389879</td><td>1</td><td>&quot;A&quot;</td></tr><tr><td>4</td><td>11.66482</td><td>0.907167</td><td>1.393929</td><td>2.285448</td><td>-2031.321622</td><td>0</td><td>&quot;A&quot;</td></tr><tr><td>6</td><td>1.522247</td><td>0.626331</td><td>0.460844</td><td>-0.060739</td><td>1487.444343</td><td>1</td><td>&quot;A&quot;</td></tr><tr><td>7</td><td>3.93548</td><td>0.363229</td><td>2.002222</td><td>-0.613627</td><td>-335.203183</td><td>0</td><td>&quot;A&quot;</td></tr><tr><td>&hellip;</td><td>&hellip;</td><td>&hellip;</td><td>&hellip;</td><td>&hellip;</td><td>&hellip;</td><td>&hellip;</td><td>&hellip;</td></tr><tr><td>99991</td><td>3.808594</td><td>0.693914</td><td>6.727779</td><td>-0.781093</td><td>-868.307031</td><td>2</td><td>&quot;C&quot;</td></tr><tr><td>99994</td><td>6.246362</td><td>0.99597</td><td>3.468162</td><td>-0.699768</td><td>-145.471814</td><td>1</td><td>&quot;C&quot;</td></tr><tr><td>99996</td><td>0.520435</td><td>0.758179</td><td>0.680518</td><td>0.788875</td><td>-3203.56896</td><td>2</td><td>&quot;C&quot;</td></tr><tr><td>99997</td><td>6.250958</td><td>0.762393</td><td>0.08691</td><td>1.79754</td><td>696.859327</td><td>1</td><td>&quot;C&quot;</td></tr><tr><td>99998</td><td>4.491091</td><td>0.396969</td><td>0.012585</td><td>2.024051</td><td>-2468.859815</td><td>2</td><td>&quot;C&quot;</td></tr></tbody></table></div>"
      ],
      "text/plain": [
       "shape: (60_000, 8)\n",
       "┌─────────┬───────────┬───────────┬──────────┬───────────┬──────────────┬───────┬──────────┐\n",
       "│ row_num ┆ uniform_1 ┆ uniform_2 ┆ exp      ┆ normal    ┆ fat_normal   ┆ flags ┆ category │\n",
       "│ ---     ┆ ---       ┆ ---       ┆ ---      ┆ ---       ┆ ---          ┆ ---   ┆ ---      │\n",
       "│ i64     ┆ f64       ┆ f64       ┆ f64      ┆ f64       ┆ f64          ┆ i32   ┆ str      │\n",
       "╞═════════╪═══════════╪═══════════╪══════════╪═══════════╪══════════════╪═══════╪══════════╡\n",
       "│ 1       ┆ 8.810768  ┆ 0.569672  ┆ 0.048483 ┆ -0.44255  ┆ 258.012662   ┆ 0     ┆ A        │\n",
       "│ 2       ┆ 3.274063  ┆ 0.632772  ┆ 0.447468 ┆ 0.255512  ┆ -1284.389879 ┆ 1     ┆ A        │\n",
       "│ 4       ┆ 11.66482  ┆ 0.907167  ┆ 1.393929 ┆ 2.285448  ┆ -2031.321622 ┆ 0     ┆ A        │\n",
       "│ 6       ┆ 1.522247  ┆ 0.626331  ┆ 0.460844 ┆ -0.060739 ┆ 1487.444343  ┆ 1     ┆ A        │\n",
       "│ 7       ┆ 3.93548   ┆ 0.363229  ┆ 2.002222 ┆ -0.613627 ┆ -335.203183  ┆ 0     ┆ A        │\n",
       "│ …       ┆ …         ┆ …         ┆ …        ┆ …         ┆ …            ┆ …     ┆ …        │\n",
       "│ 99991   ┆ 3.808594  ┆ 0.693914  ┆ 6.727779 ┆ -0.781093 ┆ -868.307031  ┆ 2     ┆ C        │\n",
       "│ 99994   ┆ 6.246362  ┆ 0.99597   ┆ 3.468162 ┆ -0.699768 ┆ -145.471814  ┆ 1     ┆ C        │\n",
       "│ 99996   ┆ 0.520435  ┆ 0.758179  ┆ 0.680518 ┆ 0.788875  ┆ -3203.56896  ┆ 2     ┆ C        │\n",
       "│ 99997   ┆ 6.250958  ┆ 0.762393  ┆ 0.08691  ┆ 1.79754   ┆ 696.859327   ┆ 1     ┆ C        │\n",
       "│ 99998   ┆ 4.491091  ┆ 0.396969  ┆ 0.012585 ┆ 2.024051  ┆ -2468.859815 ┆ 2     ┆ C        │\n",
       "└─────────┴───────────┴───────────┴──────────┴───────────┴──────────────┴───────┴──────────┘"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Random Sample\n",
    "sa.sample(df, 0.6) # by ratio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><style>\n",
       ".dataframe > thead > tr,\n",
       ".dataframe > tbody > tr {\n",
       "  text-align: right;\n",
       "  white-space: pre-wrap;\n",
       "}\n",
       "</style>\n",
       "<small>shape: (30_000, 8)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>row_num</th><th>uniform_1</th><th>uniform_2</th><th>exp</th><th>normal</th><th>fat_normal</th><th>flags</th><th>category</th></tr><tr><td>i64</td><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>f64</td><td>i32</td><td>str</td></tr></thead><tbody><tr><td>0</td><td>3.169978</td><td>0.183096</td><td>0.848878</td><td>-0.988939</td><td>369.76195</td><td>2</td><td>&quot;A&quot;</td></tr><tr><td>12</td><td>2.718784</td><td>0.236327</td><td>0.656341</td><td>2.042461</td><td>992.106646</td><td>0</td><td>&quot;A&quot;</td></tr><tr><td>13</td><td>5.688242</td><td>0.238128</td><td>1.989903</td><td>-1.890975</td><td>96.609098</td><td>1</td><td>&quot;A&quot;</td></tr><tr><td>16</td><td>10.630157</td><td>0.685417</td><td>2.040244</td><td>-0.411343</td><td>-80.440654</td><td>2</td><td>&quot;A&quot;</td></tr><tr><td>18</td><td>6.133318</td><td>0.868581</td><td>3.786928</td><td>-0.853489</td><td>-824.372864</td><td>1</td><td>&quot;A&quot;</td></tr><tr><td>&hellip;</td><td>&hellip;</td><td>&hellip;</td><td>&hellip;</td><td>&hellip;</td><td>&hellip;</td><td>&hellip;</td><td>&hellip;</td></tr><tr><td>99976</td><td>7.686265</td><td>0.037178</td><td>9.872401</td><td>0.002709</td><td>1013.823443</td><td>2</td><td>&quot;C&quot;</td></tr><tr><td>99981</td><td>9.040585</td><td>0.272563</td><td>0.423536</td><td>-0.365252</td><td>-718.151462</td><td>1</td><td>&quot;C&quot;</td></tr><tr><td>99985</td><td>8.940385</td><td>0.856215</td><td>2.355023</td><td>0.609717</td><td>-34.944096</td><td>0</td><td>&quot;C&quot;</td></tr><tr><td>99986</td><td>6.501358</td><td>0.676297</td><td>1.185671</td><td>-0.284971</td><td>583.365443</td><td>1</td><td>&quot;C&quot;</td></tr><tr><td>99996</td><td>0.520435</td><td>0.758179</td><td>0.680518</td><td>0.788875</td><td>-3203.56896</td><td>2</td><td>&quot;C&quot;</td></tr></tbody></table></div>"
      ],
      "text/plain": [
       "shape: (30_000, 8)\n",
       "┌─────────┬───────────┬───────────┬──────────┬───────────┬─────────────┬───────┬──────────┐\n",
       "│ row_num ┆ uniform_1 ┆ uniform_2 ┆ exp      ┆ normal    ┆ fat_normal  ┆ flags ┆ category │\n",
       "│ ---     ┆ ---       ┆ ---       ┆ ---      ┆ ---       ┆ ---         ┆ ---   ┆ ---      │\n",
       "│ i64     ┆ f64       ┆ f64       ┆ f64      ┆ f64       ┆ f64         ┆ i32   ┆ str      │\n",
       "╞═════════╪═══════════╪═══════════╪══════════╪═══════════╪═════════════╪═══════╪══════════╡\n",
       "│ 0       ┆ 3.169978  ┆ 0.183096  ┆ 0.848878 ┆ -0.988939 ┆ 369.76195   ┆ 2     ┆ A        │\n",
       "│ 12      ┆ 2.718784  ┆ 0.236327  ┆ 0.656341 ┆ 2.042461  ┆ 992.106646  ┆ 0     ┆ A        │\n",
       "│ 13      ┆ 5.688242  ┆ 0.238128  ┆ 1.989903 ┆ -1.890975 ┆ 96.609098   ┆ 1     ┆ A        │\n",
       "│ 16      ┆ 10.630157 ┆ 0.685417  ┆ 2.040244 ┆ -0.411343 ┆ -80.440654  ┆ 2     ┆ A        │\n",
       "│ 18      ┆ 6.133318  ┆ 0.868581  ┆ 3.786928 ┆ -0.853489 ┆ -824.372864 ┆ 1     ┆ A        │\n",
       "│ …       ┆ …         ┆ …         ┆ …        ┆ …         ┆ …           ┆ …     ┆ …        │\n",
       "│ 99976   ┆ 7.686265  ┆ 0.037178  ┆ 9.872401 ┆ 0.002709  ┆ 1013.823443 ┆ 2     ┆ C        │\n",
       "│ 99981   ┆ 9.040585  ┆ 0.272563  ┆ 0.423536 ┆ -0.365252 ┆ -718.151462 ┆ 1     ┆ C        │\n",
       "│ 99985   ┆ 8.940385  ┆ 0.856215  ┆ 2.355023 ┆ 0.609717  ┆ -34.944096  ┆ 0     ┆ C        │\n",
       "│ 99986   ┆ 6.501358  ┆ 0.676297  ┆ 1.185671 ┆ -0.284971 ┆ 583.365443  ┆ 1     ┆ C        │\n",
       "│ 99996   ┆ 0.520435  ┆ 0.758179  ┆ 0.680518 ┆ 0.788875  ┆ -3203.56896 ┆ 2     ┆ C        │\n",
       "└─────────┴───────────┴───────────┴──────────┴───────────┴─────────────┴───────┴──────────┘"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sa.sample(df, 30_000) # by count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><style>\n",
       ".dataframe > thead > tr,\n",
       ".dataframe > tbody > tr {\n",
       "  text-align: right;\n",
       "  white-space: pre-wrap;\n",
       "}\n",
       "</style>\n",
       "<small>shape: (3, 2)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>flags</th><th>len</th></tr><tr><td>i32</td><td>u32</td></tr></thead><tbody><tr><td>0</td><td>33465</td></tr><tr><td>1</td><td>33331</td></tr><tr><td>2</td><td>33204</td></tr></tbody></table></div>"
      ],
      "text/plain": [
       "shape: (3, 2)\n",
       "┌───────┬───────┐\n",
       "│ flags ┆ len   │\n",
       "│ ---   ┆ ---   │\n",
       "│ i32   ┆ u32   │\n",
       "╞═══════╪═══════╡\n",
       "│ 0     ┆ 33465 │\n",
       "│ 1     ┆ 33331 │\n",
       "│ 2     ┆ 33204 │\n",
       "└───────┴───────┘"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.group_by(\"flags\").len().sort(\"flags\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><style>\n",
       ".dataframe > thead > tr,\n",
       ".dataframe > tbody > tr {\n",
       "  text-align: right;\n",
       "  white-space: pre-wrap;\n",
       "}\n",
       "</style>\n",
       "<small>shape: (3, 2)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>flags</th><th>len</th></tr><tr><td>i32</td><td>u32</td></tr></thead><tbody><tr><td>0</td><td>16732</td></tr><tr><td>1</td><td>33331</td></tr><tr><td>2</td><td>33204</td></tr></tbody></table></div>"
      ],
      "text/plain": [
       "shape: (3, 2)\n",
       "┌───────┬───────┐\n",
       "│ flags ┆ len   │\n",
       "│ ---   ┆ ---   │\n",
       "│ i32   ┆ u32   │\n",
       "╞═══════╪═══════╡\n",
       "│ 0     ┆ 16732 │\n",
       "│ 1     ┆ 33331 │\n",
       "│ 2     ┆ 33204 │\n",
       "└───────┴───────┘"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Downsample on one group\n",
    "sa1 = sa.downsample(\n",
    "    df, \n",
    "    (pl.col(\"flags\") == 0, 0.5)\n",
    ")\n",
    "sa1.group_by(\"flags\").len().sort(\"flags\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><style>\n",
       ".dataframe > thead > tr,\n",
       ".dataframe > tbody > tr {\n",
       "  text-align: right;\n",
       "  white-space: pre-wrap;\n",
       "}\n",
       "</style>\n",
       "<small>shape: (3, 2)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>flags</th><th>len</th></tr><tr><td>i32</td><td>u32</td></tr></thead><tbody><tr><td>0</td><td>16732</td></tr><tr><td>1</td><td>9999</td></tr><tr><td>2</td><td>13281</td></tr></tbody></table></div>"
      ],
      "text/plain": [
       "shape: (3, 2)\n",
       "┌───────┬───────┐\n",
       "│ flags ┆ len   │\n",
       "│ ---   ┆ ---   │\n",
       "│ i32   ┆ u32   │\n",
       "╞═══════╪═══════╡\n",
       "│ 0     ┆ 16732 │\n",
       "│ 1     ┆ 9999  │\n",
       "│ 2     ┆ 13281 │\n",
       "└───────┴───────┘"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Downsample on multiple groups\n",
    "sa2 = sa.downsample(\n",
    "    df, \n",
    "    (pl.col(\"flags\") == 0, 0.5),\n",
    "    (pl.col(\"flags\") == 1, 0.3),\n",
    "    (pl.col(\"flags\") == 2, 0.4),\n",
    ")\n",
    "sa2.group_by(\"flags\").len().sort(\"flags\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><style>\n",
       ".dataframe > thead > tr,\n",
       ".dataframe > tbody > tr {\n",
       "  text-align: right;\n",
       "  white-space: pre-wrap;\n",
       "}\n",
       "</style>\n",
       "<small>shape: (3, 2)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>category</th><th>len</th></tr><tr><td>str</td><td>u32</td></tr></thead><tbody><tr><td>&quot;A&quot;</td><td>30000</td></tr><tr><td>&quot;B&quot;</td><td>30000</td></tr><tr><td>&quot;C&quot;</td><td>40000</td></tr></tbody></table></div>"
      ],
      "text/plain": [
       "shape: (3, 2)\n",
       "┌──────────┬───────┐\n",
       "│ category ┆ len   │\n",
       "│ ---      ┆ ---   │\n",
       "│ str      ┆ u32   │\n",
       "╞══════════╪═══════╡\n",
       "│ A        ┆ 30000 │\n",
       "│ B        ┆ 30000 │\n",
       "│ C        ┆ 40000 │\n",
       "└──────────┴───────┘"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.group_by(\"category\").len().sort(\"category\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><style>\n",
       ".dataframe > thead > tr,\n",
       ".dataframe > tbody > tr {\n",
       "  text-align: right;\n",
       "  white-space: pre-wrap;\n",
       "}\n",
       "</style>\n",
       "<small>shape: (3, 2)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>category</th><th>len</th></tr><tr><td>str</td><td>u32</td></tr></thead><tbody><tr><td>&quot;A&quot;</td><td>30000</td></tr><tr><td>&quot;B&quot;</td><td>30000</td></tr><tr><td>&quot;C&quot;</td><td>30000</td></tr></tbody></table></div>"
      ],
      "text/plain": [
       "shape: (3, 2)\n",
       "┌──────────┬───────┐\n",
       "│ category ┆ len   │\n",
       "│ ---      ┆ ---   │\n",
       "│ str      ┆ u32   │\n",
       "╞══════════╪═══════╡\n",
       "│ A        ┆ 30000 │\n",
       "│ B        ┆ 30000 │\n",
       "│ C        ┆ 30000 │\n",
       "└──────────┴───────┘"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Volume neutral by each category, will take the greatest possible value so that we get neutral volume.\n",
    "vn = sa.volume_neutral(\n",
    "    df,\n",
    "    by = pl.col(\"category\"),\n",
    ")\n",
    "vn.group_by(\"category\").len().sort(\"category\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><style>\n",
       ".dataframe > thead > tr,\n",
       ".dataframe > tbody > tr {\n",
       "  text-align: right;\n",
       "  white-space: pre-wrap;\n",
       "}\n",
       "</style>\n",
       "<small>shape: (3, 2)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>category</th><th>len</th></tr><tr><td>str</td><td>u32</td></tr></thead><tbody><tr><td>&quot;A&quot;</td><td>10000</td></tr><tr><td>&quot;B&quot;</td><td>10000</td></tr><tr><td>&quot;C&quot;</td><td>10000</td></tr></tbody></table></div>"
      ],
      "text/plain": [
       "shape: (3, 2)\n",
       "┌──────────┬───────┐\n",
       "│ category ┆ len   │\n",
       "│ ---      ┆ ---   │\n",
       "│ str      ┆ u32   │\n",
       "╞══════════╪═══════╡\n",
       "│ A        ┆ 10000 │\n",
       "│ B        ┆ 10000 │\n",
       "│ C        ┆ 10000 │\n",
       "└──────────┴───────┘"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Volume neutral (10_000) by each category\n",
    "vn = sa.volume_neutral(\n",
    "    df,\n",
    "    by = pl.col(\"category\"),\n",
    "    target_volume = 10_000\n",
    ")\n",
    "vn.group_by(\"category\").len().sort(\"category\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><style>\n",
       ".dataframe > thead > tr,\n",
       ".dataframe > tbody > tr {\n",
       "  text-align: right;\n",
       "  white-space: pre-wrap;\n",
       "}\n",
       "</style>\n",
       "<small>shape: (3, 2)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>category</th><th>len</th></tr><tr><td>str</td><td>u32</td></tr></thead><tbody><tr><td>&quot;A&quot;</td><td>10000</td></tr><tr><td>&quot;B&quot;</td><td>4285</td></tr><tr><td>&quot;C&quot;</td><td>5715</td></tr></tbody></table></div>"
      ],
      "text/plain": [
       "shape: (3, 2)\n",
       "┌──────────┬───────┐\n",
       "│ category ┆ len   │\n",
       "│ ---      ┆ ---   │\n",
       "│ str      ┆ u32   │\n",
       "╞══════════╪═══════╡\n",
       "│ A        ┆ 10000 │\n",
       "│ B        ┆ 4285  │\n",
       "│ C        ┆ 5715  │\n",
       "└──────────┴───────┘"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Volume neutral (10_000) by a more complicated condition\n",
    "vn = sa.volume_neutral(\n",
    "    df,\n",
    "    by = pl.col(\"category\") == \"A\",\n",
    "    target_volume = 10_000\n",
    ") # This makes sense because count for B + count for C = 10_000\n",
    "vn.group_by(\"category\").len().sort(\"category\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><style>\n",
       ".dataframe > thead > tr,\n",
       ".dataframe > tbody > tr {\n",
       "  text-align: right;\n",
       "  white-space: pre-wrap;\n",
       "}\n",
       "</style>\n",
       "<small>shape: (9, 3)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>category</th><th>flags</th><th>len</th></tr><tr><td>str</td><td>i32</td><td>u32</td></tr></thead><tbody><tr><td>&quot;A&quot;</td><td>0</td><td>9960</td></tr><tr><td>&quot;A&quot;</td><td>1</td><td>9960</td></tr><tr><td>&quot;A&quot;</td><td>2</td><td>9960</td></tr><tr><td>&quot;B&quot;</td><td>0</td><td>9962</td></tr><tr><td>&quot;B&quot;</td><td>1</td><td>9962</td></tr><tr><td>&quot;B&quot;</td><td>2</td><td>9962</td></tr><tr><td>&quot;C&quot;</td><td>0</td><td>13223</td></tr><tr><td>&quot;C&quot;</td><td>1</td><td>13223</td></tr><tr><td>&quot;C&quot;</td><td>2</td><td>13223</td></tr></tbody></table></div>"
      ],
      "text/plain": [
       "shape: (9, 3)\n",
       "┌──────────┬───────┬───────┐\n",
       "│ category ┆ flags ┆ len   │\n",
       "│ ---      ┆ ---   ┆ ---   │\n",
       "│ str      ┆ i32   ┆ u32   │\n",
       "╞══════════╪═══════╪═══════╡\n",
       "│ A        ┆ 0     ┆ 9960  │\n",
       "│ A        ┆ 1     ┆ 9960  │\n",
       "│ A        ┆ 2     ┆ 9960  │\n",
       "│ B        ┆ 0     ┆ 9962  │\n",
       "│ B        ┆ 1     ┆ 9962  │\n",
       "│ B        ┆ 2     ┆ 9962  │\n",
       "│ C        ┆ 0     ┆ 13223 │\n",
       "│ C        ┆ 1     ┆ 13223 │\n",
       "│ C        ┆ 2     ┆ 13223 │\n",
       "└──────────┴───────┴───────┘"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Volume neutral sample with a control level. Volume neutral happens under the category level, meaning\n",
    "# the volume for each flag in each category is neutral.\n",
    "vn = sa.volume_neutral(\n",
    "    df,\n",
    "    by = pl.col(\"flags\"),\n",
    "    control = pl.col(\"category\")\n",
    ") \n",
    "vn.group_by([\"category\", \"flags\"]).len().sort([\"category\", \"flags\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><style>\n",
       ".dataframe > thead > tr,\n",
       ".dataframe > tbody > tr {\n",
       "  text-align: right;\n",
       "  white-space: pre-wrap;\n",
       "}\n",
       "</style>\n",
       "<small>shape: (9, 3)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>category</th><th>flags</th><th>len</th></tr><tr><td>str</td><td>i32</td><td>u32</td></tr></thead><tbody><tr><td>&quot;A&quot;</td><td>0</td><td>9960</td></tr><tr><td>&quot;A&quot;</td><td>1</td><td>9960</td></tr><tr><td>&quot;A&quot;</td><td>2</td><td>9960</td></tr><tr><td>&quot;B&quot;</td><td>0</td><td>9962</td></tr><tr><td>&quot;B&quot;</td><td>1</td><td>9962</td></tr><tr><td>&quot;B&quot;</td><td>2</td><td>9962</td></tr><tr><td>&quot;C&quot;</td><td>0</td><td>10000</td></tr><tr><td>&quot;C&quot;</td><td>1</td><td>10000</td></tr><tr><td>&quot;C&quot;</td><td>2</td><td>10000</td></tr></tbody></table></div>"
      ],
      "text/plain": [
       "shape: (9, 3)\n",
       "┌──────────┬───────┬───────┐\n",
       "│ category ┆ flags ┆ len   │\n",
       "│ ---      ┆ ---   ┆ ---   │\n",
       "│ str      ┆ i32   ┆ u32   │\n",
       "╞══════════╪═══════╪═══════╡\n",
       "│ A        ┆ 0     ┆ 9960  │\n",
       "│ A        ┆ 1     ┆ 9960  │\n",
       "│ A        ┆ 2     ┆ 9960  │\n",
       "│ B        ┆ 0     ┆ 9962  │\n",
       "│ B        ┆ 1     ┆ 9962  │\n",
       "│ B        ┆ 2     ┆ 9962  │\n",
       "│ C        ┆ 0     ┆ 10000 │\n",
       "│ C        ┆ 1     ┆ 10000 │\n",
       "│ C        ┆ 2     ┆ 10000 │\n",
       "└──────────┴───────┴───────┘"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# We may not meet the target volume for all categories.\n",
    "vn = sa.volume_neutral(\n",
    "    df,\n",
    "    by = pl.col(\"flags\"),\n",
    "    control = pl.col(\"category\"),\n",
    "    target_volume= 10_000\n",
    ") # \n",
    "vn.group_by([\"category\", \"flags\"]).len().sort([\"category\", \"flags\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Splitting\n",
    "\n",
    "Split by ratios."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(100000, 8)\n",
      "(60000, 8)\n",
      "(40000, 8)\n"
     ]
    }
   ],
   "source": [
    "print(df.shape)\n",
    "train, test = sa.split_by_ratio(\n",
    "    df,\n",
    "    split_ratio = 0.6\n",
    ")\n",
    "print(train.shape)\n",
    "print(test.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(100000, 8)\n",
      "(25000, 8)\n",
      "(40000, 8)\n",
      "(10000, 8)\n",
      "(25000, 8)\n"
     ]
    }
   ],
   "source": [
    "print(df.shape)\n",
    "for frame in sa.split_by_ratio(df, split_ratio = [0.25, 0.4, 0.10, 0.25]):\n",
    "    print(frame.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
